{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Training a variational autoencoder with 2 layer fully-connected\n",
    "encoder/decoders and gaussian noise distribution.\n",
    "\n",
    "Parag K. Mital, Jan 2016\n",
    "\"\"\"\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from libs.utils import weight_variable, bias_variable, montage_batch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def VAE(input_shape=[None, 784],\n",
    "        n_components_encoder=2048,\n",
    "        n_components_decoder=2048,\n",
    "        n_hidden=2,\n",
    "        debug=False):\n",
    "    # %%\n",
    "    # Input placeholder\n",
    "    if debug:\n",
    "        input_shape = [50, 784]\n",
    "        x = tf.Variable(np.zeros((input_shape), dtype=np.float32))\n",
    "    else:\n",
    "        x = tf.placeholder(tf.float32, input_shape)\n",
    "\n",
    "    activation = tf.nn.softplus\n",
    "\n",
    "    dims = x.get_shape().as_list()\n",
    "    n_features = dims[1]\n",
    "\n",
    "    W_enc1 = weight_variable([n_features, n_components_encoder])\n",
    "    b_enc1 = bias_variable([n_components_encoder])\n",
    "    h_enc1 = activation(tf.matmul(x, W_enc1) + b_enc1)\n",
    "\n",
    "    W_enc2 = weight_variable([n_components_encoder, n_components_encoder])\n",
    "    b_enc2 = bias_variable([n_components_encoder])\n",
    "    h_enc2 = activation(tf.matmul(h_enc1, W_enc2) + b_enc2)\n",
    "\n",
    "    W_enc3 = weight_variable([n_components_encoder, n_components_encoder])\n",
    "    b_enc3 = bias_variable([n_components_encoder])\n",
    "    h_enc3 = activation(tf.matmul(h_enc2, W_enc3) + b_enc3)\n",
    "\n",
    "    W_mu = weight_variable([n_components_encoder, n_hidden])\n",
    "    b_mu = bias_variable([n_hidden])\n",
    "\n",
    "    W_log_sigma = weight_variable([n_components_encoder, n_hidden])\n",
    "    b_log_sigma = bias_variable([n_hidden])\n",
    "\n",
    "    z_mu = tf.matmul(h_enc3, W_mu) + b_mu\n",
    "    z_log_sigma = 0.5 * (tf.matmul(h_enc3, W_log_sigma) + b_log_sigma)\n",
    "\n",
    "    # %%\n",
    "    # Sample from noise distribution p(eps) ~ N(0, 1)\n",
    "    if debug:\n",
    "        epsilon = tf.random_normal(\n",
    "            [dims[0], n_hidden])\n",
    "    else:\n",
    "        epsilon = tf.random_normal(\n",
    "            tf.pack([tf.shape(x)[0], n_hidden]))\n",
    "\n",
    "    # Sample from posterior\n",
    "    z = z_mu + tf.exp(z_log_sigma) * epsilon\n",
    "\n",
    "    W_dec1 = weight_variable([n_hidden, n_components_decoder])\n",
    "    b_dec1 = bias_variable([n_components_decoder])\n",
    "    h_dec1 = activation(tf.matmul(z, W_dec1) + b_dec1)\n",
    "\n",
    "    W_dec2 = weight_variable([n_components_decoder, n_components_decoder])\n",
    "    b_dec2 = bias_variable([n_components_decoder])\n",
    "    h_dec2 = activation(tf.matmul(h_dec1, W_dec2) + b_dec2)\n",
    "\n",
    "    W_dec3 = weight_variable([n_components_decoder, n_components_decoder])\n",
    "    b_dec3 = bias_variable([n_components_decoder])\n",
    "    h_dec3 = activation(tf.matmul(h_dec2, W_dec3) + b_dec3)\n",
    "\n",
    "    W_mu_dec = weight_variable([n_components_decoder, n_features])\n",
    "    b_mu_dec = bias_variable([n_features])\n",
    "    y = tf.nn.tanh(tf.matmul(h_dec3, W_mu_dec) + b_mu_dec)\n",
    "\n",
    "    # p(x|z)\n",
    "    log_px_given_z = -tf.reduce_sum(\n",
    "        x * tf.log(y + 1e-10) +\n",
    "        (1 - x) * tf.log(1 - y + 1e-10), 1)\n",
    "\n",
    "    # d_kl(q(z|x)||p(z))\n",
    "    # Appendix B: 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    kl_div = -0.5 * tf.reduce_sum(\n",
    "        1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma),\n",
    "        1)\n",
    "    loss = tf.reduce_mean(log_px_given_z + kl_div)\n",
    "\n",
    "    return {'cost': loss, 'x': x, 'z': z, 'y': y}\n"
   ]
  },
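  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# A minimal sanity-check sketch (not part of the original script or of the\n",
    "# training graph): it compares the closed-form KL term used in VAE() above,\n",
    "# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), against a Monte Carlo\n",
    "# estimate of KL(q(z|x) || p(z)) for an arbitrary (mu, log_sigma). The two\n",
    "# numbers should agree to roughly two decimal places.\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def kl_closed_form(mu, log_sigma):\n",
    "    \"\"\"KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.\"\"\"\n",
    "    return -0.5 * np.sum(\n",
    "        1.0 + 2.0 * log_sigma - mu ** 2 - np.exp(2.0 * log_sigma))\n",
    "\n",
    "\n",
    "def kl_monte_carlo(mu, log_sigma, n_samples=200000, seed=0):\n",
    "    \"\"\"Estimate the same KL as E_q[log q(z) - log p(z)] from samples of q.\"\"\"\n",
    "    rng = np.random.RandomState(seed)\n",
    "    sigma = np.exp(log_sigma)\n",
    "    z = mu + sigma * rng.randn(n_samples, mu.size)\n",
    "    log_q = -0.5 * np.sum(\n",
    "        np.log(2 * np.pi) + 2 * log_sigma + ((z - mu) / sigma) ** 2, axis=1)\n",
    "    log_p = -0.5 * np.sum(np.log(2 * np.pi) + z ** 2, axis=1)\n",
    "    return np.mean(log_q - log_p)\n",
    "\n",
    "\n",
    "mu_example = np.array([0.5, -1.0])\n",
    "log_sigma_example = np.array([-0.2, 0.3])\n",
    "print(kl_closed_form(mu_example, log_sigma_example),\n",
    "      kl_monte_carlo(mu_example, log_sigma_example))\n"
   ]
  },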
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "def test_mnist():\n",
    "    \"\"\"Summary\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    name : TYPE\n",
    "        Description\n",
    "    \"\"\"\n",
    "    # %%\n",
    "    import tensorflow as tf\n",
    "    import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    # %%\n",
    "    # load MNIST as before\n",
    "    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "    ae = VAE()\n",
    "\n",
    "    # %%\n",
    "    learning_rate = 0.001\n",
    "    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(ae['cost'])\n",
    "\n",
    "    # %%\n",
    "    # We create a session to use the graph\n",
    "    sess = tf.Session()\n",
    "    sess.run(tf.initialize_all_variables())\n",
    "\n",
    "    # %%\n",
    "    # Fit all training data\n",
    "    t_i = 0\n",
    "    batch_size = 100\n",
    "    n_epochs = 50\n",
    "    n_examples = 20\n",
    "    test_xs, _ = mnist.test.next_batch(n_examples)\n",
    "    xs, ys = mnist.test.images, mnist.test.labels\n",
    "    fig_manifold, ax_manifold = plt.subplots(1, 1)\n",
    "    fig_reconstruction, axs_reconstruction = plt.subplots(2, n_examples, figsize=(10, 2))\n",
    "    fig_image_manifold, ax_image_manifold = plt.subplots(1, 1)\n",
    "    for epoch_i in range(n_epochs):\n",
    "        print('--- Epoch', epoch_i)\n",
    "        train_cost = 0\n",
    "        for batch_i in range(mnist.train.num_examples // batch_size):\n",
    "            batch_xs, _ = mnist.train.next_batch(batch_size)\n",
    "            train_cost += sess.run([ae['cost'], optimizer],\n",
    "                                   feed_dict={ae['x']: batch_xs})[0]\n",
    "            if batch_i % 2 == 0:\n",
    "                # %%\n",
    "                # Plot example reconstructions from latent layer\n",
    "                imgs = []\n",
    "                for img_i in np.linspace(-3, 3, n_examples):\n",
    "                    for img_j in np.linspace(-3, 3, n_examples):\n",
    "                        z = np.array([[img_i, img_j]], dtype=np.float32)\n",
    "                        recon = sess.run(ae['y'], feed_dict={ae['z']: z})\n",
    "                        imgs.append(np.reshape(recon, (1, 28, 28, 1)))\n",
    "                imgs_cat = np.concatenate(imgs)\n",
    "                ax_manifold.imshow(montage_batch(imgs_cat))\n",
    "                fig_manifold.savefig('manifold_%08d.png' % t_i)\n",
    "\n",
    "                # %%\n",
    "                # Plot example reconstructions\n",
    "                recon = sess.run(ae['y'], feed_dict={ae['x']: test_xs})\n",
    "                print(recon.shape)\n",
    "                for example_i in range(n_examples):\n",
    "                    axs_reconstruction[0][example_i].imshow(\n",
    "                        np.reshape(test_xs[example_i, :], (28, 28)),\n",
    "                        cmap='gray')\n",
    "                    axs_reconstruction[1][example_i].imshow(\n",
    "                        np.reshape(\n",
    "                            np.reshape(recon[example_i, ...], (784,)),\n",
    "                            (28, 28)),\n",
    "                        cmap='gray')\n",
    "                    axs_reconstruction[0][example_i].axis('off')\n",
    "                    axs_reconstruction[1][example_i].axis('off')\n",
    "                fig_reconstruction.savefig('reconstruction_%08d.png' % t_i)\n",
    "\n",
    "                # %%\n",
    "                # Plot manifold of latent layer\n",
    "                zs = sess.run(ae['z'], feed_dict={ae['x']: xs})\n",
    "                ax_image_manifold.clear()\n",
    "                ax_image_manifold.scatter(zs[:, 0], zs[:, 1],\n",
    "                    c=np.argmax(ys, 1), alpha=0.2)\n",
    "                ax_image_manifold.set_xlim([-6, 6])\n",
    "                ax_image_manifold.set_ylim([-6, 6])\n",
    "                ax_image_manifold.axis('off')\n",
    "                fig_image_manifold.savefig('image_manifold_%08d.png' % t_i)\n",
    "\n",
    "                t_i += 1\n",
    "\n",
    "\n",
    "        print('Train cost:', train_cost /\n",
    "              (mnist.train.num_examples // batch_size))\n",
    "\n",
    "        valid_cost = 0\n",
    "        for batch_i in range(mnist.validation.num_examples // batch_size):\n",
    "            batch_xs, _ = mnist.validation.next_batch(batch_size)\n",
    "            valid_cost += sess.run([ae['cost']],\n",
    "                                   feed_dict={ae['x']: batch_xs})[0]\n",
    "        print('Validation cost:', valid_cost /\n",
    "              (mnist.validation.num_examples // batch_size))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    test_mnist()"
   ]
  }
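  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# A hedged usage sketch (not in the original script): after training, new\n",
    "# digits can be generated by decoding latent codes drawn from the prior\n",
    "# p(z) = N(0, I), which is what the manifold grid above does on a fixed\n",
    "# lattice. It assumes a live session `sess` and the dict `ae` returned by\n",
    "# VAE() with n_hidden=2; test_mnist() above does not return these, so the\n",
    "# training code would need to be adapted to expose them.\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_from_prior(sess, ae, n_samples=16, n_hidden=2):\n",
    "    \"\"\"Decode latent codes drawn from N(0, I) into (n_samples, 28, 28) images.\"\"\"\n",
    "    z = np.random.randn(n_samples, n_hidden).astype(np.float32)\n",
    "    samples = sess.run(ae['y'], feed_dict={ae['z']: z})\n",
    "    return samples.reshape(-1, 28, 28)\n"
   ]
  }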
 ],
 "metadata": {
  "language": "python"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
